{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2018 Google LLC  \n",
    "  \n",
    " Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    " you may not use this file except in compliance with the License.  \n",
    " You may obtain a copy of the License at  \n",
    "  \n",
    "     http://www.apache.org/licenses/LICENSE-2.0  \n",
    "  \n",
    " Unless required by applicable law or agreed to in writing, software  \n",
    " distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    " See the License for the specific language governing permissions and  \n",
    " limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the sample repository and copy this template's dependency .py files\n",
    "# into the working directory. The clone is skipped when the checkout already\n",
    "# exists, so the cell can be re-run without git failing on an existing path.\n",
    "! test -d cloudml-samples || git clone https://github.com/GoogleCloudPlatform/cloudml-samples.git\n",
    "! cp cloudml-samples/tpu/templates/tpu_lstm_keras/* .\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import LSTM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_model():\n",
    "    \"\"\"Assemble the LSTM-based Keras model used by this template.\n",
    "\n",
    "    Each example is declared with shape (5, 3). It is passed through an\n",
    "    LSTM layer with 10 units and then a single-unit dense layer with a\n",
    "    sigmoid activation.\n",
    "\n",
    "    Returns:\n",
    "        The uncompiled `tf.keras.Model`.\n",
    "    \"\"\"\n",
    "    sequence_input = tf.keras.Input(shape=(5, 3))\n",
    "    lstm_output = tf.keras.layers.LSTM(10)(sequence_input)\n",
    "    prediction = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)(lstm_output)\n",
    "\n",
    "    return tf.keras.Model(inputs=sequence_input, outputs=prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_input_fn():\n",
    "    \"\"\"Build the training `tf.data.Dataset` from randomly generated data.\n",
    "\n",
    "    Returns:\n",
    "        A dataset of `(features, labels)` float32 batches of 16 examples,\n",
    "        repeated indefinitely and shuffled with a buffer of 32 elements.\n",
    "    \"\"\"\n",
    "    batch_size = 16\n",
    "\n",
    "    # Random arrays stand in for real training data.\n",
    "    fake_features = np.random.rand(100, 5, 3)\n",
    "    fake_labels = np.random.rand(100, 1)\n",
    "\n",
    "    # Convert to float32 because TPUs currently do not support float64.\n",
    "    tensors = (\n",
    "        tf.constant(fake_features, dtype=tf.float32),\n",
    "        tf.constant(fake_labels, dtype=tf.float32),\n",
    "    )\n",
    "\n",
    "    def set_shapes(features, labels):\n",
    "        # TPUs need to know all dimensions when the graph is built, whereas\n",
    "        # Datasets know the batch size only when the graph is run, so the\n",
    "        # leading dimension is pinned to the batch size here.\n",
    "        features.set_shape(\n",
    "            features.get_shape().merge_with([batch_size, None, None]))\n",
    "        labels.set_shape(\n",
    "            labels.get_shape().merge_with([batch_size, None]))\n",
    "        return features, labels\n",
    "\n",
    "    pipeline = tf.data.Dataset.from_tensor_slices(tensors)\n",
    "    pipeline = pipeline.repeat()\n",
    "    pipeline = pipeline.shuffle(32)\n",
    "    pipeline = pipeline.batch(batch_size, drop_remainder=True)\n",
    "    pipeline = pipeline.map(set_shapes)\n",
    "\n",
    "    return pipeline.prefetch(tf.contrib.data.AUTOTUNE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(args):\n",
    "    \"\"\"Build, train and save the model.\n",
    "\n",
    "    Args:\n",
    "        args: Namespace providing `use_tpu` (bool), `tpu` (name or GRPC URL\n",
    "            of the TPU node, or None) and `model_dir` (local path or\n",
    "            `gs://` URI under which `model.hd5` is written).\n",
    "    \"\"\"\n",
    "    import shutil\n",
    "    import tempfile\n",
    "\n",
    "    model = build_model()\n",
    "\n",
    "    if args.use_tpu:\n",
    "        # distribute over TPU cores\n",
    "        # Note: This requires TensorFlow 1.11\n",
    "        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(args.tpu)\n",
    "        strategy = tf.contrib.tpu.TPUDistributionStrategy(tpu_cluster_resolver)\n",
    "        model = tf.contrib.tpu.keras_to_tpu_model(\n",
    "            model, strategy=strategy)\n",
    "\n",
    "    optimizer = tf.train.RMSPropOptimizer(learning_rate=0.05)\n",
    "    loss_fn = tf.losses.log_loss\n",
    "    model.compile(optimizer, loss_fn)\n",
    "\n",
    "    model.fit(train_input_fn, epochs=3, steps_per_epoch=10)\n",
    "\n",
    "    # Use tf.gfile instead of os.path/os.makedirs: with a GCS model_dir the\n",
    "    # os functions would create a local directory named 'gs:' rather than\n",
    "    # checking or creating anything in the bucket.\n",
    "    if not tf.gfile.Exists(args.model_dir):\n",
    "        tf.gfile.MakeDirs(args.model_dir)\n",
    "\n",
    "    model_path = os.path.join(args.model_dir, 'model.hd5')\n",
    "    if args.model_dir.startswith('gs://'):\n",
    "        # The HDF5 file is written through the local filesystem, so save to\n",
    "        # a temporary directory first and then copy the file to the bucket.\n",
    "        staging_dir = tempfile.mkdtemp()\n",
    "        try:\n",
    "            local_path = os.path.join(staging_dir, 'model.hd5')\n",
    "            model.save(local_path)\n",
    "            tf.gfile.Copy(local_path, model_path, overwrite=True)\n",
    "        finally:\n",
    "            shutil.rmtree(staging_dir, ignore_errors=True)\n",
    "    else:\n",
    "        model.save(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Declare the training options as command-line flags.\n",
    "parser = argparse.ArgumentParser()\n",
    "\n",
    "# Directory (local path or GCS URI) used for output.\n",
    "parser.add_argument(\n",
    "    '--model-dir',\n",
    "    default='/tmp/tpu-template',\n",
    "    type=str,\n",
    "    help='Location to write checkpoints and summaries to.  Must be a GCS URI when using Cloud TPU.')\n",
    "\n",
    "# Boolean switch; False unless the flag is given.\n",
    "parser.add_argument(\n",
    "    '--use-tpu',\n",
    "    help='Whether to use TPU.',\n",
    "    action='store_true')\n",
    "\n",
    "# Identifier of the TPU node to connect to.\n",
    "parser.add_argument(\n",
    "    '--tpu',\n",
    "    help='The name or GRPC URL of the TPU node.  Leave it as `None` when training on CMLE.',\n",
    "    default=None)\n",
    "\n",
    "# parse_known_args() returns unrecognized arguments separately instead of\n",
    "# exiting with an error; they are discarded here.\n",
    "args, _ = parser.parse_known_args()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO(user): change this\n",
    "args.model_dir = 'gs://your-gcs-bucket'\n",
    "args.use_tpu = True\n",
    "\n",
    "# The `!hostname` ipython magic returns the command output as a list of\n",
    "# lines; its first entry is used as the TPU name.\n",
    "hostname = !hostname\n",
    "args.tpu = hostname[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use gcloud command line tool to create a TPU in the same zone as the VM instance.\n",
    "# The TPU is named after this machine's hostname, and the zone is looked up by\n",
    "# filtering `gcloud compute instances list` on that same hostname.\n",
    "! gcloud compute tpus create `hostname` \\\n",
    "  --zone `gcloud compute instances list --filter=\"name=$(hostname)\" --format 'csv[no-heading](zone)'`\\\n",
    "  --network default \\\n",
    "  --range 10.101.1.0 \\\n",
    "  --version 1.13\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build, train and save the model using the settings held in `args`.\n",
    "main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use gcloud command line tool to delete the TPU.\n",
    "# The TPU name (this machine's hostname) and the zone lookup mirror the ones\n",
    "# used when creating it; `--quiet` is passed so gcloud does not prompt for\n",
    "# confirmation.\n",
    "! gcloud compute tpus delete `hostname` \\\n",
    "  --zone `gcloud compute instances list --filter=\"name=$(hostname)\" --format 'csv[no-heading](zone)'`\\\n",
    "  --quiet\n"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}
